package service

import (
	"bufio"
	"bytes"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"gin-luban-server/global"
	"gin-luban-server/model"
	"github.com/gorilla/websocket"
	"go.uber.org/zap"
	gossh "golang.org/x/crypto/ssh"
	"log"
	"net"
	"regexp"
	"strconv"
	"sync"
	"time"
	"unicode/utf8"
)

// safeBuffer is a bytes.Buffer whose operations are serialized by a mutex so
// it can be shared between goroutines.
type safeBuffer struct {
	buffer bytes.Buffer
	mu     sync.Mutex
}

// Write appends p to the buffer. It returns whatever bytes.Buffer.Write
// returns.
func (w *safeBuffer) Write(p []byte) (int, error) {
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.buffer.Write(p)
}

// Bytes returns a copy of the unread contents of the buffer.
//
// A copy is returned on purpose: bytes.Buffer.Bytes hands out a slice that
// aliases the internal storage, which is only valid until the next Write or
// Reset. Callers here take a snapshot, Reset the buffer and keep writing, so
// an aliased slice would be silently overwritten (and would also escape the
// mutex).
func (w *safeBuffer) Bytes() []byte {
	w.mu.Lock()
	defer w.mu.Unlock()
	snapshot := make([]byte, w.buffer.Len())
	copy(snapshot, w.buffer.Bytes())
	return snapshot
}

// Reset empties the buffer while keeping its allocated storage for reuse.
func (w *safeBuffer) Reset() {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.buffer.Reset()
}


// Values of wsMsg.Type that Connect dispatches on.
const (
	// wsMsgCmd marks a message whose Cmd field carries base64-encoded terminal input.
	wsMsgCmd    = "cmd"
	// wsMsgResize marks a message whose Cols/Rows fields carry a new terminal size.
	wsMsgResize = "resize"
)

// Terminal holds the terminal dimensions that RequestTerminal places in the
// PTY request.
type Terminal struct {
	// Columns is the terminal width; sent as the PTY request's Columns value.
	Columns uint32 `json:"cols"`
	// Rows is the terminal height; sent as the PTY request's Rows value.
	Rows    uint32 `json:"rows"`
}


// wsMsg is the JSON message Connect decodes from each websocket frame.
type wsMsg struct {
	// Type selects the handler; compared against wsMsgCmd and wsMsgResize.
	Type string `json:"type"`
	// Cmd is base64 (standard encoding) terminal input; read for wsMsgCmd messages.
	Cmd  string `json:"cmd"`
	// Cols and Rows are the new terminal size; read for wsMsgResize messages
	// and ignored unless both are positive.
	Cols int    `json:"cols"`
	Rows int    `json:"rows"`
}


// ptyRequestMsg is the payload RequestTerminal serializes with gossh.Marshal
// for the "pty-req" channel request.
// NOTE(review): the field order presumably has to match the SSH "pty-req"
// wire layout (RFC 4254 section 6.2) — do not reorder without confirming.
type ptyRequestMsg struct {
	Term     string // terminal type; RequestTerminal sends "xterm"
	Columns  uint32 // taken from Terminal.Columns
	Rows     uint32 // taken from Terminal.Rows
	Width    uint32 // RequestTerminal derives this as a multiple of 8 of the cell count
	Height   uint32 // RequestTerminal derives this as a multiple of 8 of the cell count
	Modelist string // marshalled terminal modes followed by a terminating 0 byte
}

// SshClientInfo describes how to reach and authenticate against a target SSH
// host, optionally through an intermediate proxy (jump) host. GenerateClient
// consumes it to build an SSHClient.
//
// Field order is kept stable so positional composite literals elsewhere in
// the package keep compiling.
type SshClientInfo struct {
	Username  string `json:"username"`  // copied to SSHClient.Username
	SshUser   string `json:"ssh_user"`  // login user on the target host
	Password  string `json:"password"`  // used when sshType is "password"
	IpAddress string `json:"ipaddress"` // target host
	Port      int    `json:"port"`      // target port

	// The four fields below are unexported, so encoding/json never reads or
	// writes them; they must be set by code inside this package. They carry
	// no json tags because such tags are ignored (and reported by go vet).
	sshType string // "password" or "key"; selects the auth method in GenerateClient
	sshKey  string // private key material, used when sshType is "key"
	isAdmin bool   // copied to SSHClient.isAdmin
	isProxy bool   // when true GenerateClient dials through the Proxy* host

	ProxySshType  string `json:"proxy_sshType"`
	ProxyUser     string `json:"proxy_user"`
	ProxySshKey   string `json:"proxy_sshKey"`
	ProxyPassword string `json:"proxy_password"`
	ProxyHost     string `json:"proxy_host"`
	ProxyPort     int    `json:"proxy_port"`
}

// SSHClient is one web-terminal SSH session: the underlying SSH connection,
// the shell channel, and the audit state collected while the session runs.
type SSHClient struct {
	// Username, SshUser, IpAddress and Port are copied from SshClientInfo by
	// GenerateClient and written into the session's audit record.
	Username  string `json:"username"`
	SshUser   string `json:"ssh_user"`
	IpAddress string `json:"ipaddress"`
	Port      int    `json:"port"`
	// filterBuff accumulates typed input until a line terminator so the
	// completed command line can be matched against the filters.
	filterBuff        *safeBuffer
	// Session is the SSH session opened by RequestTerminal.
	Session           *gossh.Session
	// Client is the SSH connection created by GenerateClient.
	Client            *gossh.Client
	// channel is the "session" channel on which RequestTerminal requests the
	// PTY and the shell; terminal input and output flow through it.
	channel           gossh.Channel
	// StartTime is set to time.Now() by GenerateClient.
	StartTime         time.Time
	// isAdmin gates input handling: Connect only acts on websocket messages
	// when it is true.
	isAdmin           bool
	// IsFlagged is set once an input line matches one of sshFilters
	// (the tag text says: whether the current session contains a forbidden command).
	IsFlagged         bool `comment:"当前session是否包含禁止命令"`
	// HasEditor is set once an input line names a text editor (vim, nano, ...).
	HasEditor         bool
	// sshFilters are the command rules loaded via MustSshFilterGroup().Filters.
	sshFilters        model.JsonArraySshFilter
	// LogsTerm holds the terminal output chunks appended by WriteLog.
	LogsTerm          model.JsonArrayString
}
//@author: heyibo
//@function: GenerateClient
//@description: build an SSHClient and establish its SSH connection
//@return: sshClient SSHClient, err error
//
// GenerateClient copies the identifying fields of clientInfo into a new
// SSHClient, loads the command filters, and dials the target host, either
// directly or through the proxy host when clientInfo.isProxy is set.
//
// On error the returned SSHClient has no usable Client and must not be used.
func GenerateClient(clientInfo SshClientInfo) (sshClient SSHClient, err error) {
	sshClient.Username = clientInfo.Username
	sshClient.SshUser = clientInfo.SshUser
	sshClient.IpAddress = clientInfo.IpAddress
	sshClient.Port = clientInfo.Port
	sshClient.isAdmin = clientInfo.isAdmin
	sshClient.sshFilters = MustSshFilterGroup().Filters
	sshClient.StartTime = time.Now()

	targetAddr := fmt.Sprintf("%s:%d", clientInfo.IpAddress, clientInfo.Port)

	if clientInfo.isProxy {
		proxyConfig, err := NewSshClientConfig(clientInfo.ProxyUser, clientInfo.ProxyPassword, clientInfo.ProxySshType, clientInfo.ProxySshKey)
		if err != nil {
			return sshClient, errors.New("代理配置文件出错！")
		}
		// This error used to be overwritten by the dial below, so a broken
		// target config was passed on to NewSshProxyClient unnoticed.
		targetConfig, err := NewSshClientConfig(clientInfo.SshUser, clientInfo.Password, clientInfo.sshType, clientInfo.sshKey)
		if err != nil {
			return sshClient, fmt.Errorf("building target ssh config: %w", err)
		}
		proxyAddr := fmt.Sprintf("%s:%d", clientInfo.ProxyHost, clientInfo.ProxyPort)
		client, err := NewSshProxyClient(targetConfig, proxyConfig, targetAddr, proxyAddr)
		if err != nil {
			return sshClient, err
		}
		sshClient.Client = client
		return sshClient, nil
	}

	clientConfig := &gossh.ClientConfig{
		User:    clientInfo.SshUser,
		Timeout: 300 * time.Second,
		Config: gossh.Config{
			Ciphers: []string{"aes128-ctr", "aes192-ctr", "aes256-ctr", "aes128-gcm@openssh.com", "arcfour256", "arcfour128", "aes128-cbc", "3des-cbc", "aes192-cbc", "aes256-cbc"},
		},
		// NOTE(security): every host key is accepted, so the connection is
		// open to man-in-the-middle attacks. Kept as-is to preserve
		// behavior; consider verifying against known hosts.
		HostKeyCallback: func(hostname string, remote net.Addr, key gossh.PublicKey) error {
			return nil
		},
	}

	switch clientInfo.sshType {
	case "password":
		clientConfig.Auth = []gossh.AuthMethod{gossh.Password(clientInfo.Password)}
	case "key":
		// A malformed key previously yielded a nil signer that was handed
		// to PublicKeys; fail early with the parse error instead.
		signer, err := gossh.ParsePrivateKey([]byte(clientInfo.sshKey))
		if err != nil {
			return sshClient, fmt.Errorf("parsing ssh private key: %w", err)
		}
		clientConfig.Auth = []gossh.AuthMethod{gossh.PublicKeys(signer)}
	default:
		return sshClient, fmt.Errorf("unknow ssh auth type: %s", clientInfo.sshType)
	}

	client, err := gossh.Dial("tcp", targetAddr, clientConfig)
	if err != nil {
		return sshClient, err
	}
	sshClient.Client = client
	return sshClient, nil
}

//@author: heyibo
//@function: RequestTerminal
//@description: open the shell channel and request a PTY of the given size
//@return: the receiver on success, nil on any failure
//
// RequestTerminal opens an SSH session (stored in c.Session), allocates the
// input filter buffer, then opens a separate "session" channel (stored in
// c.channel) on which it requests an "xterm" PTY sized from terminal and
// starts a shell. Failures after the channel is opened are logged.
func (c *SSHClient) RequestTerminal(terminal Terminal) *SSHClient {
	session, err := c.Client.NewSession()
	if err != nil {
		return nil
	}
	c.Session = session
	c.filterBuff = new(safeBuffer)

	channel, inRequests, err := c.Client.OpenChannel("session", nil)
	if err != nil {
		log.Println(err)
		return nil
	}
	c.channel = channel

	// Decline every request the server sends on this channel; the range
	// ends when the channel is closed.
	go func() {
		for req := range inRequests {
			if req.WantReply {
				_ = req.Reply(false, nil) // best effort: nothing to do if the reply cannot be sent
			}
		}
	}()

	// Encode the terminal modes as (opcode, uint32 value) pairs followed by
	// a terminating 0 opcode.
	modes := gossh.TerminalModes{
		gossh.ECHO:          1,
		gossh.TTY_OP_ISPEED: 14400,
		gossh.TTY_OP_OSPEED: 14400,
	}
	var modeList []byte
	for k, v := range modes {
		kv := struct {
			Key byte
			Val uint32
		}{k, v}
		modeList = append(modeList, gossh.Marshal(&kv)...)
	}
	modeList = append(modeList, 0)

	req := ptyRequestMsg{
		Term:    "xterm",
		Columns: terminal.Columns,
		Rows:    terminal.Rows,
		Width:   terminal.Columns * 8,
		// Height was previously derived from Columns by mistake.
		Height:   terminal.Rows * 8,
		Modelist: string(modeList),
	}

	// The PTY must be granted before the shell is started.
	for _, step := range []struct {
		name    string
		payload []byte
	}{
		{"pty-req", gossh.Marshal(&req)},
		{"shell", nil},
	} {
		ok, err := channel.SendRequest(step.name, true, step.payload)
		if err != nil {
			log.Println(err)
		} else if !ok {
			log.Println("ssh channel request rejected by server:", step.name)
		}
		if err != nil || !ok {
			// Do not leave a half-initialised channel open on the connection.
			channel.Close()
			return nil
		}
	}
	return c
}
//@author: heyibo
//@function: Connect
//@description: bridge the websocket terminal to the SSH shell channel, intercepting filtered commands
//@return: none; blocks until the session has ended and its audit record is stored
//
// Connect runs two goroutines and waits for both:
//
//   - input: reads websocket messages. For admin sessions it applies resize
//     requests to the shell channel and forwards base64-decoded keystrokes.
//     Every completed command line (terminated by '\r', '\n' or ';') is
//     matched against the editor pattern and against c.sshFilters; on a
//     filter match a warning is written to the websocket and the payload is
//     replaced by Ctrl-U so the remote line is discarded. Messages from
//     non-admin sessions are read and dropped.
//   - output: reads runes from the shell channel, batches them and flushes
//     the batch to the websocket (and to c.LogsTerm) on a 100µs timer.
//
// As soon as either side ends, the shell channel and the websocket are
// closed so the other side unblocks. The SSH client is closed on return.
func (c *SSHClient) Connect(ws *websocket.Conn) {
	// Deferred so the connection is released after the audit record is written.
	defer c.Client.Close()

	// gorilla/websocket allows only one concurrent writer, and both
	// goroutines below write to ws, so every write goes through this mutex.
	var wsMu sync.Mutex
	writeWS := func(data []byte) error {
		wsMu.Lock()
		defer wsMu.Unlock()
		return ws.WriteMessage(websocket.TextMessage, data)
	}

	// Compile the patterns once per session instead of once per input line.
	editorRe := regexp.MustCompile(`\b(vim|vi|nano|emacs|gedit|kate|kedit)\b`)
	type blockRule struct {
		re      *regexp.Regexp
		warning string // red warning text shown in the terminal on a match
	}
	var blockRules []blockRule
	for _, rule := range c.sshFilters {
		re, err := regexp.Compile(rule.Command)
		if err != nil {
			global.GVA_LOG.Error("compile ssh filter pattern failed", zap.Any("err", err))
			continue
		}
		blockRules = append(blockRules, blockRule{
			re:      re,
			warning: fmt.Sprintf("\n\r \033[0;31m%s\033[0m\r\n", rule.Msg),
		})
	}

	done := make(chan bool, 2)  // one completion signal per goroutine below
	stop := make(chan struct{}) // closed once the first side has finished

	// Input side: websocket -> ssh.
	go func() {
		defer func() {
			done <- true
		}()
		for {
			// Always read, even for non-admin sessions: it blocks instead of
			// spinning and lets the loop notice a closed websocket.
			_, wsData, err := ws.ReadMessage()
			if err != nil {
				global.GVA_LOG.Error("reading webSocket message failed", zap.Any("err", err))
				return
			}
			if !c.isAdmin {
				continue
			}

			msgObj := wsMsg{}
			if err := json.Unmarshal(wsData, &msgObj); err != nil {
				global.GVA_LOG.Error("unmarshal websocket message failed", zap.Any("err", err))
				continue
			}

			switch msgObj.Type {
			case wsMsgResize:
				// The PTY lives on c.channel (not on c.Session), so the
				// "window-change" request must be sent there. Payload per
				// RFC 4254 section 6.7: columns, rows, width px, height px,
				// each a big-endian uint32.
				if msgObj.Cols > 0 && msgObj.Rows > 0 {
					cols, rows := uint32(msgObj.Cols), uint32(msgObj.Rows)
					payload := make([]byte, 0, 16)
					for _, v := range [...]uint32{cols, rows, cols * 8, rows * 8} {
						payload = append(payload, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
					}
					if _, err := c.channel.SendRequest("window-change", false, payload); err != nil {
						global.GVA_LOG.Error("ssh pty change windows size failed", zap.Any("err", err))
					}
				}

			case wsMsgCmd:
				decodeBytes, err := base64.StdEncoding.DecodeString(msgObj.Cmd)
				if err != nil {
					// Never forward a partially decoded payload.
					global.GVA_LOG.Error("websock cmd string base64 decoding failed", zap.Any("err", err))
					continue
				}

				// Inspect every command line completed by this message; a
				// pasted payload may contain several.
				blocked := false
				for _, bb := range decodeBytes {
					if bb != '\r' && bb != ';' && bb != '\n' {
						if _, err := c.filterBuff.Write([]byte{bb}); err != nil {
							global.GVA_LOG.Error("sws.inputFilterBuff.Write", zap.Any("err", err))
						}
						continue
					}
					// Copy before Reset: the buffer's storage is reused by
					// the writes that follow.
					line := append([]byte(nil), c.filterBuff.Bytes()...)
					c.filterBuff.Reset()
					if len(line) == 0 {
						continue
					}
					if editorRe.Match(line) {
						c.HasEditor = true
					}
					for _, rule := range blockRules {
						if !rule.re.Match(line) {
							continue
						}
						c.IsFlagged = true
						blocked = true
						if err := writeWS([]byte(rule.warning)); err != nil {
							global.GVA_LOG.Error("write filter warning to websocket failed", zap.Any("err", err))
						}
					}
				}
				if blocked {
					// 025 (octal) is Ctrl-U: it makes the remote shell drop
					// the pending line, so drop the local copy as well.
					decodeBytes = []byte{byte(025)}
					c.filterBuff.Reset()
				}

				if _, err := c.channel.Write(decodeBytes); err != nil {
					global.GVA_LOG.Error("ws cmd bytes write to ssh.stdin pipe failed", zap.Any("err", err))
				}
			}
		}
	}()

	// Output side: ssh -> websocket.
	go func() {
		defer func() {
			done <- true
		}()

		br := bufio.NewReader(c.channel)
		r := make(chan rune)

		// Reader: pushes runes from the shell channel into r until the
		// channel fails or the session is stopped. As the only sender it
		// closes r, which tells the loop below that the stream has ended.
		go func() {
			defer close(r)
			for {
				x, size, err := br.ReadRune()
				if err != nil {
					log.Println(err)
					return
				}
				if size > 0 {
					select {
					case r <- x:
					case <-stop:
						return
					}
				}
			}
		}()

		var buf []byte
		// flush sends the pending output to the websocket and records it.
		flush := func() error {
			if len(buf) == 0 {
				return nil
			}
			err := writeWS(buf)
			c.WriteLog(buf)
			buf = buf[:0] // WriteLog stored a string copy, so the storage can be reused
			return err
		}

		t := time.NewTimer(time.Microsecond * 100)
		defer t.Stop()
		for {
			select {
			// Every 100µs push whatever has accumulated.
			case <-t.C:
				if err := flush(); err != nil {
					log.Println(err)
					return
				}
				t.Reset(time.Microsecond * 100)

			case d, ok := <-r:
				if !ok {
					// Remote side is gone: deliver the remaining output,
					// tell the user, and close the websocket so the input
					// goroutine unblocks.
					if err := flush(); err != nil {
						log.Println(err)
					}
					_ = writeWS([]byte("已经关闭连接!")) // best effort; the websocket may already be closed
					ws.Close()
					return
				}
				if d != utf8.RuneError {
					var enc [utf8.UTFMax]byte
					n := utf8.EncodeRune(enc[:], d)
					buf = append(buf, enc[:n]...)
				} else {
					// Undecodable bytes are shown as '@'.
					buf = append(buf, '@')
				}
			}
		}
	}()

	// When either side finishes, tear down both transports so the other side
	// unblocks, then wait for it. Only after both are done is it safe to
	// read IsFlagged/HasEditor/LogsTerm.
	<-done
	close(stop)
	c.channel.Close()
	ws.Close()
	<-done

	jumpLogs := model.JumpServerSshLogs{
		SshUser:   c.SshUser,
		UserName:  c.Username,
		ClientIp:  c.IpAddress,
		SshPort:   strconv.Itoa(c.Port),
		StartedAt: c.StartTime,
		Remark:    "日志信息",
	}
	// NOTE(review): the meaning of the status codes is defined outside this
	// file; here 8 = a filter rule matched, 32 = an editor was started,
	// 1 = neither.
	switch {
	case c.IsFlagged:
		jumpLogs.Status = 8
	case c.HasEditor:
		jumpLogs.Status = 32
	default:
		jumpLogs.Status = 1
	}
	webTerLogs, err := json.Marshal(c.LogsTerm)
	if err != nil {
		global.GVA_LOG.Error("marshal terminal logs failed", zap.Any("err", err))
	}
	jumpLogs.WebClientLogs = string(webTerLogs)
	if err := CreateJumpServerSshLogs(jumpLogs); err != nil {
		fmt.Println("日志记录出错", err)
	}
}
//@author: heyibo
//@function: WriteLog
//@description: keep a chunk of terminal output in memory
//@return: none
//
// WriteLog appends bs, converted to a string, to c.LogsTerm. Empty input is
// ignored.
func (c *SSHClient) WriteLog(bs []byte) {
	if len(bs) == 0 {
		return
	}
	entry := string(bs)
	c.LogsTerm = append(c.LogsTerm, entry)
}

